from abc import ABC, abstractmethod
import torch.nn as nn

class BaseLossPlugin(ABC):
    """损失函数插件基类"""
    
    @abstractmethod
    def create_loss_function(self, config):
        """创建损失函数实例"""
        pass
        
    @abstractmethod
    def compute_loss(self, logits, labels, loss_mask=None):
        """计算损失"""
        pass